import numpy as np
import torch
from torch.utils import data
import torchvision
from torch.utils.data import Dataset

import os
import glob
import numpy as np
import open3d as o3d

from data.utils import generate_rand_rotm, generate_rand_trans, apply_transform

def read_pcd(filename, npoints=None, voxel_size=None):
    """Load a point cloud as a float32 array, optionally downsampled and resampled.

    Args:
        filename: path to a point-cloud file readable by
            ``open3d.io.read_point_cloud``.
        npoints: if given, return exactly this many points. Points are drawn
            without replacement when the cloud has at least ``npoints``
            points; otherwise every point is kept once and the remainder is
            padded with randomly duplicated points.
        voxel_size: if given, voxel-downsample the cloud before sampling.

    Returns:
        float32 ndarray of the cloud's points (``npoints`` rows when
        ``npoints`` is given).

    Raises:
        ValueError: if ``npoints`` points are requested from an empty cloud.
    """
    pcd = o3d.io.read_point_cloud(filename)

    if voxel_size is not None:
        pcd = pcd.voxel_down_sample(voxel_size=voxel_size)

    # Read the points after the optional downsampling so ``scan`` is always
    # bound (previously it was undefined when npoints was set without voxel_size).
    scan = np.asarray(pcd.points)

    if npoints is None:
        return scan.astype('float32')

    n = scan.shape[0]
    if n >= npoints:
        sample_idx = np.random.choice(n, npoints, replace=False)
    elif n == 0:
        # Nothing to duplicate from; fail with a clear message instead of the
        # opaque error np.random.choice would raise.
        raise ValueError(
            f"cannot sample {npoints} points from empty point cloud {filename!r}")
    else:
        # Keep every real point once, then pad with random duplicates.
        padding = np.random.choice(n, npoints - n, replace=True)
        sample_idx = np.concatenate((np.arange(n), padding), axis=-1)

    return scan[sample_idx, :].astype('float32')

class ApolloDataset(Dataset):
    '''Pairs of point-cloud scans with their relative poses.

    Each name in ``seqs`` selects a pair list ``<data_list>/<seq>.txt``. Every
    non-blank line of that file holds two scan paths (relative to ``root``)
    followed by 12 floats. The floats are reshaped row-major into the top 3x4
    of a 4x4 homogeneous transform ``Tr``.

    Consecutive pairs are grouped into contiguous segments. A new segment
    starts at the beginning of each pair-list file, and whenever the
    timestamp parsed from the source filename differs by more than 2 from the
    previous line's. ``data_len_sequence`` holds the segment lengths in
    dataset order; they sum to ``len(self)``.

    Params:
        root: directory that the scan paths in the pair lists are relative to.
        seqs: sequence names; each selects the pair list ``<seq>.txt``.
        npoints: points per returned cloud (forwarded to ``read_pcd``);
            None keeps every point.
        voxel_size: voxel size for optional downsampling in ``read_pcd``;
            None disables downsampling.
        data_list: directory containing the ``<seq>.txt`` pair lists.
        augment: probability of applying a random rotation to the source
            cloud in ``__getitem__``.
    '''
    def __init__(self, root, seqs, npoints, voxel_size, data_list, augment=0.0):
        super(ApolloDataset, self).__init__()

        self.root = root
        self.seqs = seqs
        self.npoints = npoints
        self.voxel_size = voxel_size
        self.augment = augment
        self.data_list = data_list
        self.data_len_sequence = []
        self.dataset = self.make_dataset()
        self.randg = np.random.RandomState()

    def make_dataset(self):
        '''Parse every pair list into sample dicts and record segment lengths.

        Returns:
            list of dicts with keys ``'points1'``/``'points2'`` (full scan
            paths), ``'Tr'`` (4x4 float32 relative pose) and ``'seq_index'``
            (position of the pair within its segment).

        Side effect: rebuilds ``self.data_len_sequence``.
        '''
        last_row = np.array([[0.0, 0.0, 0.0, 1.0]], dtype=np.float32)
        dataset = []
        self.data_len_sequence = []

        for seq in self.seqs:
            fn_pair_poses = os.path.join(self.data_list, seq + '.txt')
            seq_count = 0
            last_timestamp = None

            with open(fn_pair_poses, 'r') as f:
                for line in f:
                    fields = line.split()
                    if not fields:
                        continue  # tolerate blank lines (e.g. trailing ones)

                    src_fn = os.path.join(self.root, fields[0])
                    dst_fn = os.path.join(self.root, fields[1])
                    timestamp = self.extract_timestamps(src_fn)

                    values = np.array([float(v) for v in fields[2:]], dtype=np.float32)
                    rela_pose = np.concatenate([values.reshape(3, 4), last_row], axis=0)

                    # A jump of more than 2 in the filename timestamp closes the
                    # current segment and starts a new one.
                    if last_timestamp is not None and abs(timestamp - last_timestamp) > 2:
                        self.data_len_sequence.append(seq_count)
                        seq_count = 0

                    dataset.append({
                        'points1': src_fn,
                        'points2': dst_fn,
                        'Tr': rela_pose,
                        'seq_index': seq_count,
                    })
                    last_timestamp = timestamp
                    seq_count += 1

            # Close this file's final segment. Previously it was never recorded,
            # which left __getitem__ unable to find its length.
            if seq_count > 0:
                self.data_len_sequence.append(seq_count)

        return dataset

    def extract_timestamps(self, filepath):
        '''Return the integer before the first '.' in the file's base name.'''
        filename = os.path.basename(filepath)
        return int(filename.split(".")[0])

    def _segment_length(self, index):
        '''Return the length of the contiguous segment containing sample ``index``.'''
        if index < 0:
            index += len(self.dataset)
        for seg_len in self.data_len_sequence:
            if index < seg_len:
                return seg_len
            index -= seg_len
        raise IndexError('sample index out of range of data_len_sequence')

    def __getitem__(self, index):
        '''Load one scan pair.

        Returns:
            (src_points, dst_points, R, t, seq_index, seq_len): float32 point
            tensors for both scans, the 3x3 rotation and 3-vector translation
            of ``Tr``, the pair's position within its segment, and that
            segment's length.
        '''
        data_dict = self.dataset[index]
        src_points = read_pcd(data_dict['points1'], self.npoints, self.voxel_size)
        dst_points = read_pcd(data_dict['points2'], self.npoints, self.voxel_size)
        Tr = data_dict['Tr']

        # Random rotation augmentation (only for training feature extraction).
        # The source cloud is rotated by aug_T, so Tr is right-multiplied by
        # inv(aug_T) to keep mapping the rotated source onto the target.
        # NOTE(review): this draws from the global np.random state, so
        # reset_seed() (which seeds self.randg) does not affect it.
        if np.random.rand() < self.augment:
            aug_T = np.zeros((4, 4), dtype=np.float32)
            aug_T[3, 3] = 1.0
            aug_T[:3, :3] = generate_rand_rotm(0.0, 0.0, 45.0)
            src_points = apply_transform(src_points, aug_T)
            Tr = Tr.dot(np.linalg.inv(aug_T))

        src_points = torch.from_numpy(src_points)
        dst_points = torch.from_numpy(dst_points)
        Tr = torch.from_numpy(Tr)
        R = Tr[:3, :3]
        t = Tr[:3, 3]
        return src_points, dst_points, R, t, data_dict["seq_index"], self._segment_length(index)

    def reset_seed(self, seed=0):
        '''Seed the per-dataset RandomState ``self.randg``.'''
        self.randg.seed(seed)

    def __len__(self):
        return len(self.dataset)